'''
Author: 梦付千秋星垂野 465943794@qq.com
Date: 2023-03-27 15:26:47
LastEditors: 梦付千秋星垂野 465943794@qq.com
LastEditTime: 2023-03-27 17:13:11
FilePath: /base_machinelearning/DigitRecognizer/predit.py
Description: Load the trained AlexNet checkpoint (epoch50-AlexNet.pth), predict a label for every row of the digit-recognizer test.csv, and write the results to submission_AlexNet_50.csv.
'''
import pandas as pd
import torch
from utils.DataLoader import MNIST_data
from torchvision import transforms
from model.AlexNet import Net
from torch.autograd import Variable
import numpy as np
# Inference settings: batch size for the test DataLoader and the number of
# pixel columns per CSV row (784 = 28 * 28, presumably one 28x28 image).
batch_size = 64
n_pixels = 784
test_dataset = MNIST_data('../datasets/digit-recognizer/test.csv', n_pixels=n_pixels)

# shuffle=False keeps the CSV order, so the i-th prediction lines up with
# the i-th ImageId written to the submission file.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, shuffle=False)

# Resolve the target device once. Passing map_location lets a checkpoint
# that was saved from GPU tensors be loaded on a CPU-only machine; without
# it torch.load tries to restore the tensors onto CUDA and fails there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net()
model.load_state_dict(torch.load('epoch50-AlexNet.pth', map_location=device))
model = model.to(device)
def prediciton(data_loader, net=None):
    """Predict a class index for every sample yielded by ``data_loader``.

    Args:
        data_loader: Iterable of input batches (tensors) that ``net`` can be
            called on directly. Predictions are concatenated in the order
            the batches are yielded.
        net: Model to run. Defaults to the module-level ``model``, so the
            original one-argument call keeps working.

    Returns:
        A CPU ``torch.int64`` tensor of shape ``(N, 1)``. Each entry is the
        argmax over dim 1 of the corresponding model output row, and ``N``
        is the total number of samples. If the loader yields no batches,
        an empty ``LongTensor`` is returned instead.
    """
    if net is None:
        net = model

    # Send inputs to wherever the model's weights live. For a model with no
    # parameters, fall back to the "CUDA if available" rule.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net.eval()
    batch_preds = []
    with torch.no_grad():
        for data in data_loader:
            output = net(data.to(target_device))
            # Index of the highest score per row. On ties this is the first
            # maximal index, the same as output.max(1)[1].
            batch_preds.append(output.argmax(dim=1, keepdim=True).cpu())

    if not batch_preds:
        return torch.LongTensor()
    # Concatenate once at the end. Calling torch.cat inside the loop would
    # re-copy the growing result on every batch, which is quadratic.
    return torch.cat(batch_preds, dim=0)
# Predict every test image and save a two-column CSV (ImageId, Label).
# ImageId is 1-based and follows the loader's (unshuffled) order.
test_pred = prediciton(test_loader)
image_ids = np.arange(1, len(test_dataset) + 1)
labels = test_pred.numpy().reshape(-1)
out_df = pd.DataFrame({'ImageId': image_ids, 'Label': labels})
out_df.to_csv('submission_AlexNet_50.csv', index=False)